Define the model:
# JAGS model string for an open-population counting model:
#   m[k]   ~ Pois(lambda)        new arrivals at step k
#   z[k-1] ~ Bin(delta, n[k-1])  survivors from the previous step
#   n[k]   = m[k] + z[k-1]       latent abundance
#   y[k]   ~ Bin(rho, n[k])      observed counts
# The priors on lambda/rho/delta are commented out INSIDE the string:
# in this script those parameters are passed to JAGS as fixed data
# (see the data list in the jags.model() call below).
model_str <- "
model {
# Define the priors
#lambda ~ dgamma(1, 1) # average number of new arrivals
#rho ~ dbeta(1, 1) # detection probability
#delta ~ dbeta(1, 1) # survival probability
# Initialize the model for k = 1
m[1] ~ dpois(lambda) # number of new arrivals
n[1] <- m[1] # abundance
y[1] ~ dbin(rho, n[1]) # number of observed individuals
# Define the model for k > 1
for (k in 2:K) {
z[k - 1] ~ dbin(delta, n[k - 1])
m[k] ~ dpois(lambda)
n[k] <- m[k] + z[k - 1]
y[k] ~ dbin(rho, n[k])
}
}
"
# Wrap the string in a connection so jags.model() can read it like a file.
model_conn <- textConnection(model_str)
Run JAGS and obtain samples:
library(rjags)
# Libraries for plotting
library(dplyr)
library(tidyr)
library(lazyeval)
library(ggplot2)
# Define data
# K, lambda, delta and rho are fed to JAGS as fixed data (the priors in the
# model string are commented out), so only m, n and z are sampled.
K <- 10 # number of time steps
n_iter <- 10000 # MCMC iterations to draw
lambda <- 10 # rate of new arrivals
delta <- 0.5 # survival probability
rho <- 0.8 # detection probability
y <- c(10, 11, 7, 15, 13, 17, 13, 13, 18, 15) # observed counts, one per step
# Initialize values
m0 <- y # start with m[k] = y[k]; feasible since n[k] >= m[k] must cover y[k]
# Run JAGS
# NOTE(review): n.adapt = 0 skips JAGS's adaptation phase entirely —
# confirm this is intentional (samplers may mix less efficiently).
model <- jags.model(model_conn,
data = list('y' = y, 'K' = K, 'lambda' = lambda,
'delta' = delta, 'rho' = rho),
inits = list('m' = m0), # 'z' = z0
n.chains = 1,
n.adapt = 0)
# Draw n_iter iterations keeping every 20th (500 retained draws) and time
# the run. NOTE(review): `t` masks base::t() (transpose) from here on;
# t[3] (elapsed seconds) is used below to build the time axis.
t <- system.time(
samples <- coda.samples(model, c('n', 'm', 'z'), n.iter = n_iter, thin = 20)
)
Plot the cumulative mean vs iteration for each variable n, m and z at each time step k:
# Turn mcmc object into a dataframe
# Spread the total elapsed time (t[3]) evenly over the retained draws to get
# an approximate per-draw timestamp.
t_vec <- cumsum(rep(t[3] / nrow(samples[[1]]), nrow(samples[[1]])))
samples_df <- data.frame(samples[[1]], iter = 1:nrow(samples[[1]]), time = t_vec)
# Remove periods from column names
# (data.frame() mangles "n[1]" into "n.1."; stripping dots yields "n1")
colnames(samples_df) <- gsub('\\.', '', colnames(samples_df))
# Preprocess the data frame for ggplot
samples_df <- samples_df %>%
gather(var_k, value, -iter, -time) %>% # long format: one row per draw per variable
separate(var_k, into = c('var', 'k'), sep = 1) %>% # split after 1st char: "n10" -> "n","10"
group_by(var, k) %>%
mutate(cum_mean = cumsum(value) / iter) # running mean; note the grouping persists downstream
samples_df$k <- factor(samples_df$k, levels = 1:K)
# Plot each variable
# Running mean vs iteration, one line per time step k, for each monitored
# variable n, m and z.
for (v in c('n', 'm', 'z')) {
# Plain dplyr filter: `v` is resolved in the calling environment, so it
# cannot collide with the `var` column. This replaces the deprecated
# filter_()/lazyeval::interp() idiom, which was only needed because the
# old loop variable `var` shadowed both the column and stats::var().
samples_var <- samples_df %>%
filter(var == v)
plt <- ggplot(samples_var, aes(iter, cum_mean, colour = k)) +
geom_line() +
theme_bw() +
ggtitle(paste('Mean vs iteration of', v))
print(plt)
}
Discard samples from iteration \(\leq\) 100:
# Drop the first 100 retained iterations as burn-in and recompute the
# running mean over what is left. NOTE: samples_df is still grouped by
# (var, k) from the earlier group_by(), so cumsum() runs within each
# group as intended; (iter - 100) is the number of draws kept so far.
samples_df_new <- samples_df %>% filter(iter > 100) %>%
mutate(cum_mean = cumsum(value) / (iter - 100))
Compare the new sample means of n with their true means:
# True posterior means of n[k] for k = 1..K, computed externally, used as
# the reference for the sampler's running means.
true_means <- c(11.8206, 13.6123, 11.0323, 17.8103, 17.1073, 20.2344, 17.0239, 17.1461, 21.3691, 19.2281)
samples_n <- samples_df_new %>%
filter(var == 'n')
# Absolute difference at each time step
# The last-iteration row of each (var, k) group carries the final running
# mean. Renamed from `summary` so base::summary() is not masked.
n_summary <- samples_n %>% filter(iter == max(iter)) %>%
select(var, k, sample_mean = cum_mean)
# NOTE(review): relies on the rows coming out in k = 1..10 order to align
# with true_means; this holds if the mcmc columns were n[1]..n[10] in
# index order — verify against the coda output.
n_summary$true_mean <- true_means
n_summary$diff <- abs(n_summary$sample_mean - n_summary$true_mean)
knitr::kable(n_summary)
| var | k | sample_mean | true_mean | diff |
|---|---|---|---|---|
| n | 1 | 11.8725 | 11.8206 | 0.0519 |
| n | 2 | 13.6175 | 13.6123 | 0.0052 |
| n | 3 | 10.9675 | 11.0323 | 0.0648 |
| n | 4 | 17.8475 | 17.8103 | 0.0372 |
| n | 5 | 17.0925 | 17.1073 | 0.0148 |
| n | 6 | 20.2375 | 20.2344 | 0.0031 |
| n | 7 | 17.0200 | 17.0239 | 0.0039 |
| n | 8 | 17.1400 | 17.1461 | 0.0061 |
| n | 9 | 21.5250 | 21.3691 | 0.1559 |
| n | 10 | 19.1725 | 19.2281 | 0.0556 |
# MSE vs running time
# Reshape the running means into an (iterations x K) matrix: the long data
# is stacked one time step after another, so each column holds one k.
n_kept <- length(unique(samples_n$iter))
sample_means <- matrix(samples_n$cum_mean, nrow = n_kept)
# Mean squared error of the K running means against the true means, per row
# (i.e. per iteration).
mse <- rowMeans(sweep(sample_means, 2, true_means)^2)
mse_df <- data.frame(mse, time = sort(unique(samples_n$time)))
ggplot(mse_df, aes(time, mse)) +
geom_line() +
theme_bw() +
xlab('time in seconds') +
ggtitle('MSE vs running time')
Plot the distribution of n at each time step:
samples_n <- samples_df_new %>%
filter(var == 'n')
# True posterior marginals of n[k]: assumed one column per time step k and
# one row per support point of n. NOTE(review): dimensions are never
# checked here — verify the CSV is (support x K).
true_marg <- read.csv('../data/true_marg.csv', header = FALSE)
for (k in 1:K) {
# Samples for this time step: interp() substitutes the loop value of k as
# t_step, while `k` inside the formula refers to the column.
samples_nk <- samples_n %>%
filter_(interp(~ k == t_step, t_step = k))
# Histogram of samples
x_range <- min(samples_nk$value):max(samples_nk$value)
hist_plt_title <- bquote("Sample count of" ~ n[.(k)])
hist_plt <- ggplot(samples_nk, aes(value)) +
geom_bar() +
scale_x_continuous(breaks = x_range) +
theme_bw() +
ggtitle(hist_plt_title)
print(hist_plt)
# Density of samples
dist_plt_title <- bquote("Sample posterior marginal of" ~ n[.(k)])
dist_plt <- ggplot(samples_nk, aes(value, ..density..)) +
geom_freqpoly(binwidth = 1, na.rm = TRUE) +
geom_vline(aes(xintercept = mean(value)),
color = 'red', linetype = 'dashed') +
xlim(0, 40) +
theme_bw() +
ggtitle(dist_plt_title)
print(dist_plt)
# MSE vs running time
# One-hot encode each draw over the support 0..39: start with a
# 40 x (number of draws) zero matrix and mark each draw's value.
sample_marg <- matrix(0, 40, nrow(samples_nk))
n <- samples_nk$value
# Column-major linear index of entry (n + 1, draw index).
# NOTE(review): silently assumes every sampled value is <= 39.
sample_marg[nrow(sample_marg) * ((1:length(n)) - 1) + n + 1] <- 1
sample_marg <- t(sample_marg) # now draws x support
sample_marg <- apply(sample_marg, 2, cumsum) # cumulative counts per support value
norm <- apply(sample_marg, 1, sum) # row i sums to i (draws seen so far)
# Element-wise division recycles norm down the rows (length(norm) equals
# nrow), turning row i into the empirical marginal after i draws.
sample_marg <- sample_marg / norm
mse_plt_title <- bquote("MSE vs running time of" ~ n[.(k)])
# NOTE(review): length(true_marg) is the number of COLUMNS of the data
# frame (K = 10), not the number of support points (40); dividing by
# nrow(true_marg) looks like what was intended — verify, since the plot's
# absolute scale depends on it.
mse <- rowSums(sweep(sample_marg, 2, true_marg[, k], '-')^2)/length(true_marg)
mse_df <- data.frame(mse, time = samples_nk$time)
mse_plt <- ggplot(mse_df, aes(time, mse)) +
geom_line() +
theme_bw() +
xlab('time in seconds') +
ggtitle(mse_plt_title)
print(mse_plt)
}
# Parameter names with their ground-truth simulation values
params <- c('lambda', 'delta', 'rho')
true_params <- setNames(c(lambda, delta, rho), params)
# PGF estimates
# Point estimates obtained with the PGF method, for comparison
pgf_params <- setNames(c(9.566, 0.48, 0.8858), params)
# One row per (parameter, source): the 'true' values and the 'pgf' estimates
compare_params <- rbind(
data.frame(param = params, value = true_params, type = 'true'),
data.frame(param = params, value = pgf_params, type = 'pgf')
)
# Load samples
# Long-format parameter draws; columns include iter, time, chain, param,
# value plus running estimates cum_mean and cum_mode.
samples_df <- read.csv('../data/param_est/converge/samples_conv_iter1m_thin100_wo_burnin.csv',
stringsAsFactors = FALSE)
# Scatter plot of rho vs lambda
# Drop the running estimates and pivot to one column per parameter so that
# the lambda and rho of the same draw land in the same row.
wide_df <- samples_df %>%
select(-cum_mean, -cum_mode) %>%
spread(param, value)
ggplot(wide_df, aes(lambda, rho)) +
geom_point() +
ggtitle('Scatter plot of rho vs lambda')
# Preprocess df for ggplot
# Stack cum_mean/cum_mode into (estimator, estimate) pairs.
# NOTE(review): time is divided by 60 but the plots below label the axis
# 'time in hours' — one of the two is off unless the CSV stores minutes;
# confirm the CSV's time unit.
est_df <- samples_df %>% na.omit() %>%
gather(estimator, estimate, -iter, -time, -param, -chain, -value) %>%
select(time, chain, param, estimator, estimate) %>%
mutate(time = time / 60)
# Raw draws per chain and parameter, used for the density plots
den_df <- samples_df %>% select(chain, param, value)
# For each parameter: plot the running estimates, their absolute error, and
# the posterior density of the raw draws, against the true value and the
# PGF estimate.
for (p in params) {
# Plain dplyr filter: `p` is looked up in the calling environment, so it
# cannot collide with the `param` column (replaces the deprecated
# filter_()/lazyeval::interp() idiom that existed to dodge the shadowing
# of the column by the old loop variable `param`).
est_param <- est_df %>%
filter(param == p) %>%
mutate(diff = abs(estimate - true_params[p]))
den_param <- den_df %>%
filter(param == p)
compare_param <- compare_params %>% filter(param == p)
# Plot mean and mode vs running time
est_plt <- ggplot(est_param,
aes(time, estimate, colour = estimator, linetype = as.factor(chain))) +
geom_line() +
#geom_hline(aes(yintercept = pgf_params[p]),
# color = 'red', linetype = 'dashed') +
# Reference lines for the true and PGF values. `linetype = 'dashed'` is a
# fixed aesthetic, so it belongs OUTSIDE aes(); putting the string inside
# aes() maps it as data and corrupts the linetype scale already used for
# as.factor(chain). Also: `show_guide` is deprecated — use `show.legend`
# (this silences the warnings seen in the rendered output).
geom_hline(data = compare_param,
aes(yintercept = value, colour = type),
linetype = 'dashed',
show.legend = TRUE) +
theme_bw() +
xlab('time in hours') +
ggtitle(paste('Estimate vs running time of', p))
print(est_plt)
# Plot error vs running time
error_plt <- ggplot(est_param,
aes(time, diff, colour = estimator, linetype = as.factor(chain))) +
geom_line() +
theme_bw() +
xlab('time in hours') +
ggtitle(paste('Error vs running time of', p))
print(error_plt)
# Plot density of samples
d_plt <- ggplot(den_param, aes(value, linetype = as.factor(chain))) +
geom_density(bw = 'SJ') + # Sheather-Jones bandwidth selector
geom_vline(aes(xintercept = true_params[p]),
color = 'red', linetype = 'dashed') +
theme_bw() +
ggtitle(paste('Density of', p))
print(d_plt)
#tmp <- mcmc.list(mcmc((den_param %>% filter(chain == 1))$value),
# mcmc((den_param %>% filter(chain == 2))$value))
#print(HPDinterval(tmp, 0.95))
}
## Warning: `show_guide` has been deprecated. Please use `show.legend`
## instead.
## Warning: `show_guide` has been deprecated. Please use `show.legend`
## instead.
## Warning: `show_guide` has been deprecated. Please use `show.legend`
## instead.
# Final (last-iteration) running mean/mode per chain and parameter, with
# absolute errors against the known true values. The CSV also contains a
# 'lambda_rho' entry with no counterpart in true_params, so it is filtered
# out. Renamed from `summary` so base::summary() is not masked.
param_summary <- samples_df %>%
filter(iter == max(iter), param != 'lambda_rho') %>%
select(chain, param, cum_mean, cum_mode) %>%
mutate(error_mean = abs(cum_mean - true_params[param]),
error_mode = abs(cum_mode - true_params[param]))
knitr::kable(param_summary)
| chain | param | cum_mean | cum_mode | error_mean | error_mode |
|---|---|---|---|---|---|
| 1 | delta | 0.4801563 | 0.4791839 | 0.0198437 | 0.0208161 |
| 1 | lambda | 9.7000552 | 9.5099003 | 0.2999448 | 0.4900997 |
| 1 | rho | 0.8755396 | 0.8687355 | 0.0755396 | 0.0687355 |